# build module
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import logging
LOGGER = logging.getLogger(__name__)
LOGGER.addHandler(logging.NullHandler())


class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.weight.data.normal_(mean=0.0, std=0.05)

        self.bias = nn.Parameter(torch.zeros(hidden_size))

        self.query = nn.Parameter(torch.Tensor(hidden_size))
        self.query.data.normal_(mean=0.0, std=0.05)

    def forward(self, batch_hidden, batch_masks):
        # batch_hidden: b x len x hidden_size (hidden_size = 2 * LSTM hidden size,
        #               since the encoders below are bidirectional)
        # batch_masks:  b x len, 1.0 for real tokens, 0.0 for padding

        # linear
        key = torch.matmul(batch_hidden, self.weight) + self.bias  # b x len x hidden

        # compute attention
        outputs = torch.matmul(key, self.query)  # b x len

        masked_outputs = outputs.masked_fill(batch_masks == 0, -1e32)  # fill padded positions (mask == 0) with -1e32 so they vanish under softmax

        # softmax over the sequence dimension gives the attention scores
        attn_scores = F.softmax(masked_outputs, dim=1)  # b x len

        # Re-mask the scores: for an all-padding row, softmax over -1e32 yields the
        # uniform 1/len (whereas -inf would yield NaN), so zero those positions explicitly.
        masked_attn_scores = attn_scores.masked_fill(batch_masks == 0, 0.0)  # b x len

        # weighted sum of the projected hidden states
        batch_outputs = torch.bmm(masked_attn_scores.unsqueeze(1), key).squeeze(1)  # b x hidden

        return batch_outputs, attn_scores
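

# A minimal shape check for Attention (a sketch; the sizes below are
# illustrative, not values used elsewhere in this module):
def _demo_attention():
    att = Attention(hidden_size=8)
    hidden = torch.randn(2, 5, 8)                 # b x len x hidden
    masks = torch.tensor([[1., 1., 1., 0., 0.],
                          [1., 1., 1., 1., 0.]])  # b x len, 0 = padding
    reps, scores = att(hidden, masks)
    assert reps.shape == (2, 8) and scores.shape == (2, 5)
    # padded positions receive (numerically) zero attention after the -1e32 fill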


# build word encoder
word2vec_path = '../emb/word2vec.txt'
dropout = 0.15
word_hidden_size = 128
word_num_layers = 2


class WordLSTMEncoder(nn.Module):
    def __init__(self, vocab):
        super(WordLSTMEncoder, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.word_dims = 100

        self.word_embed = nn.Embedding(vocab.word_size, self.word_dims, padding_idx=0)

        extword_embed = vocab.load_pretrained_embs(word2vec_path)
        extword_size, word_dims = extword_embed.shape
        LOGGER.info("Load extword embed: words %d, dims %d.", extword_size, word_dims)

        self.extword_embed = nn.Embedding(extword_size, word_dims, padding_idx=0)
        self.extword_embed.weight.data.copy_(torch.from_numpy(extword_embed))
        self.extword_embed.weight.requires_grad = False

        input_size = self.word_dims

        self.word_lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=word_hidden_size,
            num_layers=word_num_layers,
            batch_first=True,
            bidirectional=True
        )

    def forward(self, word_ids, extword_ids, batch_masks):
        # word_ids: sen_num x sent_len
        # extword_ids: sen_num x sent_len
        # batch_masks   sen_num x sent_len

        word_embed = self.word_embed(word_ids)  # sen_num x sent_len x 100
        extword_embed = self.extword_embed(extword_ids)
        batch_embed = word_embed + extword_embed

        if self.training:
            batch_embed = self.dropout(batch_embed)

        hiddens, _ = self.word_lstm(batch_embed)  # sen_num x sent_len x hidden*2
        hiddens = hiddens * batch_masks.unsqueeze(2)  # zero out padded positions

        if self.training:
            hiddens = self.dropout(hiddens)

        return hiddens
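

# A minimal smoke test for WordLSTMEncoder (a sketch: _StubVocab and its random
# embedding matrix are hypothetical stand-ins for the project's real vocab object
# and the word2vec file, which are defined elsewhere):
def _demo_word_encoder():
    class _StubVocab:
        word_size = 50

        def load_pretrained_embs(self, path):
            # illustrative random embeddings in place of the real pretrained file
            return np.random.randn(50, 100).astype(np.float32)

    encoder = WordLSTMEncoder(_StubVocab())
    ids = torch.randint(0, 50, (6, 12))   # sen_num x sent_len
    masks = (ids != 0).float()            # word-level mask, 0 = padding
    hiddens = encoder(ids, ids, masks)
    assert hiddens.shape == (6, 12, 2 * word_hidden_size)  # sen_num x sent_len x 256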


# build sent encoder
sent_hidden_size = 256
sent_num_layers = 2


class SentEncoder(nn.Module):
    def __init__(self, sent_rep_size):
        super(SentEncoder, self).__init__()
        self.dropout = nn.Dropout(dropout)

        self.sent_lstm = nn.LSTM(
            input_size=sent_rep_size,
            hidden_size=sent_hidden_size,
            num_layers=sent_num_layers,
            batch_first=True,
            bidirectional=True
        )

    def forward(self, sent_reps, sent_masks):
        # sent_reps:  b x doc_len x sent_rep_size
        # sent_masks: b x doc_len

        sent_hiddens, _ = self.sent_lstm(sent_reps)  # b x doc_len x hidden*2
        sent_hiddens = sent_hiddens * sent_masks.unsqueeze(2)  # zero out padded sentences

        if self.training:
            sent_hiddens = self.dropout(sent_hiddens)

        return sent_hiddens
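

# A quick shape check for SentEncoder (a sketch; sent_rep_size is the
# bidirectional word-LSTM output size, word_hidden_size * 2):
def _demo_sent_encoder():
    enc = SentEncoder(sent_rep_size=word_hidden_size * 2)
    reps = torch.randn(2, 4, word_hidden_size * 2)  # b x doc_len x sent_rep_size
    masks = torch.tensor([[1., 1., 1., 0.],
                          [1., 1., 1., 1.]])        # b x doc_len, 0 = padded sentence
    hiddens = enc(reps, masks)
    assert hiddens.shape == (2, 4, 2 * sent_hidden_size)  # b x doc_len x 512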


# build model
class Model(nn.Module):
    def __init__(self, vocab):
        super(Model, self).__init__()
        self.sent_rep_size = word_hidden_size * 2
        self.doc_rep_size = sent_hidden_size * 2
        self.all_parameters = {}
        parameters = []
        self.word_encoder = WordLSTMEncoder(vocab)
        self.word_attention = Attention(self.sent_rep_size)
        parameters.extend(list(filter(lambda p: p.requires_grad, self.word_encoder.parameters())))
        parameters.extend(list(filter(lambda p: p.requires_grad, self.word_attention.parameters())))

        self.sent_encoder = SentEncoder(self.sent_rep_size)
        self.sent_attention = Attention(self.doc_rep_size)
        parameters.extend(list(filter(lambda p: p.requires_grad, self.sent_encoder.parameters())))
        parameters.extend(list(filter(lambda p: p.requires_grad, self.sent_attention.parameters())))

        self.out = nn.Linear(self.doc_rep_size, vocab.label_size, bias=True)
        parameters.extend(list(filter(lambda p: p.requires_grad, self.out.parameters())))

        if len(parameters) > 0:
            self.all_parameters["basic_parameters"] = parameters

        LOGGER.info('Build model with lstm word encoder, lstm sent encoder.')

        para_num = sum(p.numel() for p in self.parameters())
        LOGGER.info('Model param num: %.2f M.', para_num / 1e6)

    def forward(self, batch_inputs):
        # batch_inputs(batch_inputs1, batch_inputs2): b x doc_len x sent_len
        # batch_masks : b x doc_len x sent_len
        batch_inputs1, batch_inputs2, batch_masks = batch_inputs

        # First flatten b x doc_len x sent_len ([batch_size, sentences per document,
        # words per sentence]) into sen_num x sent_len ([total number of sentences
        # in the batch, words per sentence]).
        batch_size, max_doc_len, max_sent_len = batch_inputs1.shape
        batch_inputs1 = batch_inputs1.view(batch_size * max_doc_len, max_sent_len)  # sen_num x sent_len
        batch_inputs2 = batch_inputs2.view(batch_size * max_doc_len, max_sent_len)  # sen_num x sent_len
        batch_masks = batch_masks.view(batch_size * max_doc_len, max_sent_len)  # sen_num x sent_len

        # Run the flattened sentences through the word-level LSTM to get one
        # representation per word. Note that batch_masks is still a word-level mask here.
        batch_hiddens = self.word_encoder(batch_inputs1, batch_inputs2,
                                          batch_masks)  # sen_num x sent_len x sent_rep_size

        # Word-level attention pools each sentence's word representations into a
        # single sentence representation.
        sent_reps, atten_scores = self.word_attention(batch_hiddens, batch_masks)  # sen_num x sent_rep_size

        # Unflatten back to b x doc_len x sent_rep_size ([batch_size, sentences per
        # document, sentence representation size]).
        sent_reps = sent_reps.view(batch_size, max_doc_len, self.sent_rep_size)  # b x doc_len x sent_rep_size
        batch_masks = batch_masks.view(batch_size, max_doc_len, max_sent_len)  # b x doc_len x max_sent_len
        # Collapse the word-level mask (dim 2) to a sentence-level mask: a sentence
        # counts as valid if at least one of its words is unmasked.
        sent_masks = batch_masks.bool().any(2).float()  # b x doc_len

        # Run the sentence representations through the sentence-level LSTM; sent_masks
        # is now a sentence-level mask.
        sent_hiddens = self.sent_encoder(sent_reps, sent_masks)  # b x doc_len x doc_rep_size
        # Sentence-level attention pools the sentence representations into a single
        # document representation.
        doc_reps, atten_scores = self.sent_attention(sent_hiddens, sent_masks)  # b x doc_rep_size

        # Project the document representation to label logits (softmax/cross-entropy
        # is applied outside the model).
        batch_outputs = self.out(doc_reps)  # b x num_labels

        return batch_outputs
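

# An end-to-end smoke test for the full model (a sketch; _StubVocab below is a
# hypothetical stand-in for the project's real vocab object):
if __name__ == "__main__":
    class _StubVocab:
        word_size = 50
        label_size = 3

        def load_pretrained_embs(self, path):
            # illustrative random embeddings in place of the real pretrained file
            return np.random.randn(50, 100).astype(np.float32)

    model = Model(_StubVocab())
    b, doc_len, sent_len = 2, 3, 12
    ids = torch.randint(1, 50, (b, doc_len, sent_len))  # avoid padding id 0
    masks = torch.ones(b, doc_len, sent_len)
    logits = model((ids, ids, masks))
    print(logits.shape)  # torch.Size([2, 3]) -> b x num_labels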
